/*    Copyright 2010 Tobias Marschall
 *
 *    This file is part of MoSDi.
 *
 *    MoSDi is free software: you can redistribute it and/or modify
 *    it under the terms of the GNU General Public License as published by
 *    the Free Software Foundation, either version 3 of the License, or
 *    (at your option) any later version.
 *
 *    MoSDi is distributed in the hope that it will be useful,
 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *    GNU General Public License for more details.
 *
 *    You should have received a copy of the GNU General Public License
 *    along with MoSDi.  If not, see <http://www.gnu.org/licenses/>.
 */

package mosdi.paa;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.Map;
import java.util.Queue;
import java.util.TreeMap;

import mosdi.fa.CharacterAutomaton;
import mosdi.fa.FiniteMemoryTextModel;

/** Markov chain obtained as the product of a character automaton (CDFA) and a
 *  finite-memory text model. Every chain state is a pair
 *  (automaton state, text model state); only pairs reachable from a pair whose
 *  automaton component is the automaton's start state are enumerated.
 *  The initial distribution assigns the text model's equilibrium probabilities
 *  to the pairs whose automaton component is the start state. */
public class ProductMarkovChain implements MarkovChain {
	private CharacterAutomaton automaton;
	private FiniteMemoryTextModel textModel;
	private int alphabetSize;
	/** Pair represented by each chain state, indexed by chain state. */
	private ProductState[] productStates;
	/** For each chain state, the indices of its successor states in ascending order. */
	private ArrayList<int[]> targets;
	/** Transition probabilities, parallel to {@link #targets}. */
	private ArrayList<double[]> targetProbabilities;
	/** For each chain state, the indices of its predecessor states; built lazily. */
	private int[][] preimages;
	/** Transition probabilities, parallel to {@link #preimages}; built lazily. */
	private double[][] preimagesProbabilities;

	/** Pair (automaton state, text model state) usable as a hash map key. */
	private static class ProductState {
		int automatonState;
		int textModelState;
		ProductState(int automatonState, int textModelState) {
			this.automatonState = automatonState;
			this.textModelState = textModelState;
		}
		/** Two pairs are equal iff both components coincide. */
		@Override
		public boolean equals(Object obj) {
			if (obj instanceof ProductState) {
				ProductState other = (ProductState)obj;
				return automatonState==other.automatonState && textModelState==other.textModelState;
			}
			return false;
		}
		/** Combines both components; the automaton state is bit-reversed so that
		 *  small values of the two components affect different ends of the hash. */
		@Override
		public int hashCode() {
			return Integer.reverse(automatonState) ^ textModelState;
		}
		@Override
		public String toString() {
			return "("+automatonState+","+textModelState+")";
		}
	}

	/** Builds the product chain by a breadth-first exploration that starts from
	 *  all pairs (automaton start state, text model state).
	 *  @param automaton automaton supplying the first component of each pair.
	 *  @param textModel text model supplying the second component and all probabilities.
	 *  @throws IllegalArgumentException if the two alphabet sizes differ. */
	public ProductMarkovChain(CharacterAutomaton automaton, FiniteMemoryTextModel textModel) {
		if (automaton.getAlphabetSize()!=textModel.getAlphabetSize()) throw new IllegalArgumentException("Alphabet size mismatch.");
		this.alphabetSize = automaton.getAlphabetSize();
		this.textModel = textModel;
		this.automaton = automaton;
		this.targets = new ArrayList<int[]>();
		this.targetProbabilities = new ArrayList<double[]>();

		// Index assigned to every pair discovered so far.
		Map<ProductState,Integer> indexOf = new HashMap<ProductState,Integer>();
		// Discovered pairs whose outgoing transitions are still to be computed.
		// Pairs leave this queue in index order, so the rows appended to
		// targets/targetProbabilities line up with the state indices.
		Queue<ProductState> pending = new LinkedList<ProductState>();

		for (int tm=0; tm<textModel.getStateCount(); ++tm) {
			register(new ProductState(automaton.getStartState(), tm), indexOf, pending);
		}

		while (!pending.isEmpty()) {
			ProductState source = pending.remove();
			// Successor index -> accumulated probability; the sorted map yields ascending targets.
			Map<Integer,Double> row = new TreeMap<Integer,Double>();
			for (int c=0; c<alphabetSize; ++c) {
				int nextAutomatonState = automaton.getTransitionTarget(source.automatonState, c);
				for (int nextTextModelState : textModel.getTransitionTargets(source.textModelState, c)) {
					ProductState successor = new ProductState(nextAutomatonState, nextTextModelState);
					Integer known = indexOf.get(successor);
					int successorIndex = (known!=null) ? known.intValue() : register(successor, indexOf, pending);
					// Several characters may lead to the same successor; add their probabilities up.
					Double accumulated = row.get(successorIndex);
					double mass = (accumulated==null) ? 0.0 : accumulated.doubleValue();
					row.put(successorIndex, mass + textModel.getProbability(source.textModelState, c, nextTextModelState));
				}
			}
			int size = row.size();
			int[] rowTargets = new int[size];
			double[] rowProbabilities = new double[size];
			int pos = 0;
			for (Map.Entry<Integer,Double> entry : row.entrySet()) {
				rowTargets[pos] = entry.getKey();
				rowProbabilities[pos] = entry.getValue();
				++pos;
			}
			targets.add(rowTargets);
			targetProbabilities.add(rowProbabilities);
		}

		// Invert the index map to obtain the pair belonging to each index.
		productStates = new ProductState[indexOf.size()];
		for (Map.Entry<ProductState,Integer> entry : indexOf.entrySet()) {
			productStates[entry.getValue()] = entry.getKey();
		}
	}

	/** Assigns the next free index to a pair that has not been seen before and
	 *  schedules it for exploration.
	 *  @return the index given to the pair. */
	private static int register(ProductState state, Map<ProductState,Integer> indexOf, Queue<ProductState> pending) {
		// Indices are handed out consecutively, so the next free one equals the map size.
		int index = indexOf.size();
		indexOf.put(state, index);
		pending.add(state);
		return index;
	}

	/** Returns the automaton component of the pair
	 *  (automatonState, textModelState) that the given chain state stands for. */
	public int getAutomatonState(int state) {
		return productStates[state].automatonState;
	}

	/** Returns the text model component of the pair
	 *  (automatonState, textModelState) that the given chain state stands for. */
	public int getTextModelState(int state) {
		return productStates[state].textModelState;
	}

	/** Returns the successor states of the given state in ascending order.
	 *  The internal array is returned, not a copy. */
	public int[] getTargets(int state) {
		return targets.get(state);
	}

	/** Returns the transition probabilities belonging to the entries of
	 *  {@link #getTargets(int)}. The internal array is returned, not a copy. */
	public double[] getTargetProbabilities(int state) {
		return targetProbabilities.get(state);
	}

	/** Builds the predecessor tables by inverting the successor tables. */
	private void precomputePreimages() {
		final int stateCount = productStates.length;
		// Pass 1: count the incoming transitions of every state.
		int[] inDegree = new int[stateCount];
		for (int[] row : targets) {
			for (int target : row) ++inDegree[target];
		}
		int[][] sources = new int[stateCount][];
		double[][] sourceProbabilities = new double[stateCount][];
		for (int s=0; s<stateCount; ++s) {
			sources[s] = new int[inDegree[s]];
			sourceProbabilities[s] = new double[inDegree[s]];
		}
		// Pass 2: fill the tables; cursor[t] is the next free slot of state t.
		int[] cursor = new int[stateCount];
		for (int source=0; source<stateCount; ++source) {
			int[] row = targets.get(source);
			double[] rowProbabilities = targetProbabilities.get(source);
			for (int j=0; j<row.length; ++j) {
				int target = row[j];
				int slot = cursor[target]++;
				sources[target][slot] = source;
				sourceProbabilities[target][slot] = rowProbabilities[j];
			}
		}
		preimages = sources;
		preimagesProbabilities = sourceProbabilities;
	}

	/** Returns the states that have a transition into the given state.
	 *  The table is computed on first use; the internal array is returned. */
	public int[] getPreimage(int state) {
		if (preimages==null) {
			precomputePreimages();
		}
		return preimages[state];
	}

	/** Returns the transition probabilities belonging to the entries of
	 *  {@link #getPreimage(int)}. The table is computed on first use. */
	public double[] getPreimageProbabilities(int state) {
		if (preimagesProbabilities==null) {
			precomputePreimages();
		}
		return preimagesProbabilities[state];
	}

	/** Returns the text model's equilibrium probability of the state's text model
	 *  component if its automaton component is the start state, and zero otherwise. */
	@Override
	public double getInitialProbability(int state) {
		ProductState pair = productStates[state];
		if (pair.automatonState != automaton.getStartState()) {
			return 0;
		}
		return textModel.getEquilibriumProbability(pair.textModelState);
	}

	@Override
	public int getStateCount() {
		return productStates.length;
	}

	/** Returns the probability of moving from {@code state} to {@code targetState},
	 *  or zero if there is no such transition. */
	@Override
	public double getTransitionProbability(int state, int targetState) {
		int[] row = targets.get(state);
		int i = 0;
		while (i<row.length && row[i]!=targetState) {
			++i;
		}
		return (i<row.length) ? targetProbabilities.get(state)[i] : 0.0;
	}

}
